Training improvements #17
Conversation
…cls to training_args
iit/model_pairs/base_model_pair.py (outdated diff)

@@ -18,6 +20,24 @@
 from iit.utils.index import Ix, TorchIndex
 from iit.utils.metric import MetricStoreCollection, MetricType

+def in_notebook() -> bool:
I think we can do this by just importing tqdm: much cleaner that way. (at least according to this)
I tried using just tqdm, but it definitely didn't work in notebook mode. I think I hunted down all of the print statements, too.
A cleaner compromise than what's here now: I moved this block to utils/tqdm.py and added `from iit.utils.tqdm import tqdm`.
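A minimal sketch of what such a utils/tqdm.py shim could look like (the detection logic here is an assumption; the PR's version may differ):

```python
# iit/utils/tqdm.py (sketch)
def in_notebook() -> bool:
    """Best-effort check for a Jupyter/IPython kernel."""
    try:
        from IPython import get_ipython
        shell = get_ipython()
        # ZMQInteractiveShell is the shell class used by Jupyter kernels.
        return shell is not None and shell.__class__.__name__ == "ZMQInteractiveShell"
    except ImportError:
        return False

if in_notebook():
    from tqdm.notebook import tqdm  # HTML widget progress bars
else:
    from tqdm import tqdm  # plain terminal progress bars
```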
iit/model_pairs/base_model_pair.py (outdated diff)

@@ -177,7 +197,7 @@ def get_IIT_loss_over_batch(
 hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
 label_idx = self.get_label_idxs()
 # IIT loss is only computed on the tokens we care about
-loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index])
+loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index])
We should probably just raise if the dataset, hl_model, and ll_model aren't on the same device during init / at the start of training. This usually just hides the real problem and makes bugs harder to find.
Makes sense, I'll add an assert to the beginning of train() and remove all of these.
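Something along these lines could work for that assert (a sketch only; hl_model/ll_model come from the discussion above, and the dataset access pattern is an assumption):

```python
import torch as t

def assert_devices_match(hl_model: t.nn.Module, ll_model: t.nn.Module, dataset) -> None:
    # Fail fast at the start of train() instead of papering over mismatches
    # with scattered .to(...) calls deeper in the loss code.
    hl_device = next(hl_model.parameters()).device
    ll_device = next(ll_model.parameters()).device
    sample = dataset[0][0]  # assumes (input, label) items holding tensors
    assert hl_device == ll_device == sample.device, (
        f"hl_model ({hl_device}), ll_model ({ll_device}) and dataset "
        f"({sample.device}) must all be on the same device before training"
    )
```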
iit/model_pairs/base_model_pair.py (outdated diff)

 if early_stop and self._check_early_stop_condition(test_metrics):
     break
 epoch_pbar.update(1)
Would be nicer if we could move the entire logic into _print_and_log_metrics; current_epoch_log can remain there. And logging it to wandb might be useful as well!
Moved this logic to _print_and_log_metrics. I think everything that makes up the string is already being logged to wandb.
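A rough sketch of what the consolidated helper could look like (standalone for illustration; the repo's _print_and_log_metrics is a method and may differ, and the wandb call is an assumption based on the suggestion above):

```python
import wandb

def print_and_log_metrics(epoch: int, metrics, use_wandb: bool = False) -> None:
    # Build the whole epoch summary in one place instead of scattering prints.
    current_epoch_log = f"Epoch {epoch}: "
    for metric in metrics:
        current_epoch_log += str(metric) + ", "
        if use_wandb:
            # Mirror the printed values to wandb for run tracking.
            wandb.log({metric.get_name(): metric.get_value()}, step=epoch)
    print(current_epoch_log.rstrip(", "))
```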
iit/model_pairs/base_model_pair.py (outdated diff)

 for metric in metrics:
     print(metric, end=", ")
     if metric.type == MetricType.ACCURACY:
         current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, "
str(metric) does this automatically
Changed to `current_epoch_log += str(metric) + ", "`.
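Presumably the metric class's __str__ already produces the name/value pair; a hypothetical reconstruction of that behaviour, inferred from the diff above (not the repo's actual class):

```python
class Metric:
    def __init__(self, name: str, value: float):
        self._name, self._value = name, value

    def get_name(self) -> str:
        return self._name

    def get_value(self) -> float:
        return self._value

    def __str__(self) -> str:
        # Matches the manual formatting the diff above was duplicating.
        return f"{self.get_name()}: {self.get_value():.2f}"

print(Metric("val/accuracy", 0.97))  # val/accuracy: 0.97
```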
@@ -21,14 +21,9 @@ def __init__(
    training_args: dict = {}
):
    default_training_args = {
        "batch_size": 256,
Would be really helpful if we could maintain the default args as they were before. Or at least store the default hyperparams we used before in some config for reproducibility.
I think all the defaults are preserved (they were just set in multiple places) EXCEPT I did change use_single_loss and optimizer_kwargs. I'll change those back to the defaults from before.
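One common pattern for keeping a single canonical set of defaults that callers can override (illustrative only; the values shown are the ones visible in the diffs here, and the rest are elided):

```python
default_training_args = {
    "batch_size": 256,
    "strict_weight": 1.0,
    "clip_grad_norm": 1.0,
    # ... remaining defaults ...
}
# User-supplied values win; anything unspecified falls back to the defaults.
training_args = {**default_training_args, **training_args}
```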
"strict_weight": 1.0, | ||
"clip_grad_norm": 1.0, | ||
"strict_weight_schedule" : lambda s, i: s, |
This is a cool idea!
Maybe it is better to implement it as

    @property
    def strict_weight_at_epoch(self):
        return self.training_args.strict_weight_schedule(<args_from_self>)

instead of changing the strict weight variable after each epoch? (Or as a method, like strict_weight_for_epoch = self.get_scheduled_strict_weight(), and then calculate the loss.)
This lambda is also throwing me off a bit; maybe renaming the args would make it clearer.
It also seems like this is achievable by using different optimisers/lrs for each loss (and not using single loss)? No idea which one's better, though...
Hm, it's unclear to me which is the right way to go right now. I think it's best to remove it (right now it's not doing anything), and if you find this idea useful down the road you can add it back however you see fit?
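For the record, a property-based version of the schedule idea might have looked like this (purely illustrative; the class and argument names are assumptions replacing the PR's `s` and `i`):

```python
class SchedulePreview:
    def __init__(self, training_args: dict):
        self.training_args = training_args
        self.current_epoch = 0

    @property
    def strict_weight(self) -> float:
        # schedule(base_weight, epoch) -> weight for this epoch; the
        # identity schedule lambda s, i: s reproduces a constant weight.
        schedule = self.training_args["strict_weight_schedule"]
        return schedule(self.training_args["strict_weight"], self.current_epoch)

pair = SchedulePreview({"strict_weight": 1.0, "strict_weight_schedule": lambda s, i: s})
print(pair.strict_weight)  # 1.0 at every epoch
```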
-iit_loss = 0
-ll_loss = 0
-behavior_loss = 0
+iit_loss = t.zeros(1)
I'm not completely sure why this is needed; you can usually add floats and tensors without messing up the grads, right?
This is for mypy type-checking. step_on_loss expects a Tensor instead of a float, and the .item() call at the end is a type error if this isn't a Tensor.
I think the right way to resolve this is to remove the isinstance(..., Tensor) logic at the end of the function, since it's now always a tensor. I'll do that.
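The resolution described here, as a self-contained sketch (the helper name is hypothetical):

```python
import torch as t

def sum_losses(per_batch_losses: list[t.Tensor]) -> float:
    # Accumulate into a zero *tensor* rather than the float 0, so the final
    # .item() call type-checks under mypy without an isinstance(...) guard.
    total = t.zeros(1)
    for loss in per_batch_losses:
        total = total + loss
    return total.item()

print(sum_losses([t.tensor([0.5]), t.tensor([0.25])]))  # 0.75
```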
iit/model_pairs/base_model_pair.py (outdated diff)

@@ -8,6 +9,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm  # type: ignore
 from transformer_lens.hook_points import HookedRootModule, HookPoint  # type: ignore
+from IPython.display import clear_output
Why is this needed here? Don't think it is being used...
It's not! Good catch, that was leftover from getting tqdm stuff working.
)

# Set seed before iterating on loaders for reproducibility.
t.manual_seed(training_args["seed"])
Is it possible to use a generator for the loaders, like we do for numpy? I think I used to set this once globally in the training script before; my bad. :(
I'm not totally sure? I got this solution here. It seems like the random operation is set when the dataloader is turned into an iterable, and someone could use torch functions between initializing and training the model pair, which could hinder reproducibility without putting something here.
It doesn't look particularly problematic. I'll have a more careful look in a while. Thanks for adding these! Will definitely check my cases, though. This seems important in general; somehow I can't reproduce the 4 new trained cases after pulling the newer PRs. Maybe these changes help. :")
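For reference, the generator-based alternative asked about above would look roughly like this (a sketch, not the PR's code):

```python
import torch as t
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(t.arange(8).float())

# A dedicated generator scopes the shuffle RNG to this loader, so unrelated
# torch.* random calls between model-pair init and training can't change
# the batch order.
g = t.Generator()
g.manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)

for (batch,) in loader:
    print(batch)
```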
OK! I think I responded to all of your comments and pushed updates. I also found a problem that was causing the circuits-bench tests to fail and fixed it, so they're all passing on my end. I'll be offline for the next three weeks starting in a few hours, so if there are other problems / stylistic things, please feel free to edit the branch of my repo / this PR to fix them!
Also, I just added back one .to(device) in the eval() step. It's really helpful for me not to have my entire dataset on cuda / mps, especially when training successive models in a notebook, so putting the dataset labels on the model's device in run_eval_step is helpful.
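The pattern described here, sketched as a standalone function (the name and signature are assumptions, not the repo's actual run_eval_step):

```python
import torch as t

def eval_step_sketch(model: t.nn.Module, inputs: t.Tensor, labels: t.Tensor) -> float:
    # The dataset stays wherever it lives (e.g. CPU); only this batch hops
    # to the model's device, so successive models don't fight over memory.
    device = next(model.parameters()).device
    outputs = model(inputs.to(device))
    return (outputs.argmax(dim=-1) == labels.to(device)).float().mean().item()
```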
Great! The changes look fine now. Merging. Thanks for the PR!
A bunch of small changes to make training smoother & a bit more robust; most importantly:
mypy type checking and pytest tests/ pass, but it's possible some downstream stuff broke? Unclear.